"""
Phase 7: Capital Account Openness — Does Openness Amplify Demographic Risk?
=============================================================================
Dedicated analysis of KAOPEN interaction with demographics for crisis
prediction and CA reversals. Three approaches:
  1. Z × KAOPEN interactions (continuous)
  2. Sample split by KAOPEN tercile (low / mid / high)
  3. Split-sample for both banking crisis and CA reversal
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# Suppress every warning for the whole process (no category or module filter
# is given, so this also hides pandas/numpy warnings).
warnings.filterwarnings('ignore')

# PROJECT_DIR is the parent of the directory containing this script;
# ROOT_DIR is one level above that.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral/src" directory first on the import path so the
# `model` module (which provides PanelGLS) can be imported from it.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# DATA is where the input panel CSV is read from; OUT_TABLES is where the
# markdown tables are written. The output directory is created at import time
# if it does not already exist.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one table cell pair for a coefficient.

    Returns a 2-tuple of strings: the coefficient with four decimals followed
    by its significance stars, and the standard error with four decimals
    wrapped in parentheses.
    """
    marker = stars(p)
    coef_text = "{:.4f}{}".format(val, marker)
    se_text = "({:.4f})".format(se)
    return coef_text, se_text


def run_gls(df, y_var, x_vars, label):
    """Fit a PanelGLS model and return its results as a flat dict.

    Parameters
    ----------
    df : DataFrame holding ``y_var``, every name in ``x_vars``, and the panel
        identifiers ``'iso3'`` and ``'year'``.
    y_var : name of the dependent-variable column.
    x_vars : sequence of regressor column names (any sequence, not only list).
    label : model name used in console output and stored under ``'model'``.

    Returns
    -------
    dict or None
        ``None`` when the specification cannot be estimated: a required column
        is absent, fewer than 50 complete rows remain after dropping missing
        values, or the fit raises. Otherwise a dict with ``'model'``,
        ``'n_obs'``, ``'n_countries'``, ``'r_squared'``, ``'rho'`` and, for
        each regressor, ``'<name>_coef'``, ``'<name>_se'`` and ``'<name>_p'``.
        Every skip is reported on stdout.
    """
    x_vars = list(x_vars)
    cols = [y_var] + x_vars + ['iso3', 'year']

    # A missing column is treated like any other unusable specification
    # (skip and report) instead of aborting the whole script with a KeyError.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  {label}: missing columns ({', '.join(missing)}), skipping")
        return None

    # Listwise deletion: keep only rows complete in every column used.
    sub = df[cols].dropna()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    y = sub[y_var].values
    X = sub[x_vars].values
    try:
        gls.fit(y, X, sub['iso3'].values, sub['year'].values)
    except Exception as e:
        # Deliberate best-effort: one failed specification must not stop the
        # remaining models, so the failure is reported and the model skipped.
        print(f"  {label}: GLS failed ({e}), skipping")
        return None

    result = {
        'model': label,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }
    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    # NOTE(review): assumes gls.beta / gls.se / gls.pvalues are ordered like
    # the columns of X (i.e. no intercept is prepended) — confirm in PanelGLS.
    for i, name in enumerate(x_vars):
        sig = stars(gls.pvalues[i])
        print(f"    {name:30s} {gls.beta[i]:8.4f} ({gls.se[i]:.4f}) {sig}")
        result[f'{name}_coef'] = gls.beta[i]
        result[f'{name}_se'] = gls.se[i]
        result[f'{name}_p'] = gls.pvalues[i]

    return result


def write_table(results, filename, title, note=None):
    """Write regression results as a markdown table in OUT_TABLES.

    Parameters
    ----------
    results : list of result dicts as produced by ``run_gls``; each becomes
        one column. Nothing is written when the list is empty or ``None``.
    filename : name of the file to create inside ``OUT_TABLES``.
    title : text of the level-1 markdown heading.
    note : optional footnote appended after the table. When omitted, a
        default note describing the estimator and the star convention is used.

    The table has two rows per variable (coefficient with stars, then the
    standard error in parentheses), followed by N, R² and country counts.
    Variables appear in order of first occurrence across the models; a model
    that lacks a variable gets empty cells in that variable's rows.
    """
    if not results:
        return

    suffix = '_coef'
    all_vars = []
    for r in results:
        for k in r:
            if k.endswith(suffix):
                # Strip only the trailing suffix; str.replace would also
                # remove '_coef' from the middle of a variable name and
                # break the f'{var}_coef' lookups below.
                vname = k[:-len(suffix)]
                if vname not in all_vars:
                    all_vars.append(vname)

    # The same alignment row is used under the header and again before the
    # N / R² / Countries rows.
    divider = "|:---|" + "|".join(["---:" for _ in results]) + "|"

    model_labels = [r['model'] for r in results]
    lines = [f"# {title}\n"]
    lines.append("| Variable | " + " | ".join(model_labels) + " |")
    lines.append(divider)

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    lines.append(divider)
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.append(n_row)
    lines.append(r2_row)
    lines.append(nc_row)

    if note:
        lines.append(f"\n{note}")
    else:
        lines.append("\n*Panel GLS with country and year fixed effects. "
                     "Standard errors in parentheses.*")
        lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # The table contains non-ASCII characters ("R²", and "≤"/"≥" in notes), so
    # the encoding is pinned instead of relying on the platform default.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


def _fit_by_tercile(subsets, y_var, x_vars, label_prefix):
    """Run run_gls on each tercile subset and return the successful results.

    ``subsets`` maps tercile name to its DataFrame; models are labelled
    ``'<label_prefix>: <tercile>'`` and fitted in the mapping's order.
    """
    fitted = []
    for tercile, sub in subsets.items():
        r = run_gls(sub, y_var, x_vars, f'{label_prefix}: {tercile}')
        if r is not None:
            fitted.append(r)
    return fitted


def main():
    """Run the Phase 7 KAOPEN analysis and write its four markdown tables.

    Reads ``crises_panel.csv`` from ``DATA``, assigns each observation to a
    KAOPEN tercile using cutoffs from the pooled distribution, and then:

    * Tables A and B (``kaopen_banking.md``, ``kaopen_reversal.md``): full
      sample, Z x KAOPEN interaction, and per-tercile regressions for banking
      crisis onset and CA reversal.
    * Table C (``kaopen_channels.md``): youth/old dependency regressions by
      tercile for both outcomes.
    * Summary (``kaopen_summary.md``): unconditional crisis rates and variable
      means by tercile.

    Progress and per-model coefficients are printed to stdout.
    """
    print("=" * 70)
    print("PHASE 7: CAPITAL ACCOUNT OPENNESS AND DEMOGRAPHIC CRISIS RISK")
    print("=" * 70)

    df = pd.read_csv(DATA / "crises_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth']
    z_vars = ['Z_1', 'Z_2', 'Z_3']
    terciles = ['Closed', 'Mid', 'Open']

    # ── KAOPEN terciles ──
    kaopen_valid = df['kaopen'].dropna()
    t1 = kaopen_valid.quantile(1/3)
    t2 = kaopen_valid.quantile(2/3)
    # NOTE: pd.cut bins are right-closed by default, so a value equal to t1
    # lands in 'Closed' and a value equal to t2 lands in 'Mid'; 'Open' is
    # strictly above t2. pd.cut also requires t1 < t2.
    df['kaopen_tercile'] = pd.cut(df['kaopen'], bins=[-np.inf, t1, t2, np.inf],
                                   labels=terciles)
    print(f"\n  KAOPEN tercile cutoffs: ≤{t1:.2f} (Closed), {t1:.2f}–{t2:.2f} (Mid), ≥{t2:.2f} (Open)")

    # Filter once and reuse: df is not modified after this point and run_gls
    # only reads from the frames it is given, so no defensive copies are needed.
    subsets = {t: df[df['kaopen_tercile'] == t] for t in terciles}
    for t, sub in subsets.items():
        print(f"    {t}: {sub['iso3'].nunique()} countries, {len(sub)} obs, "
              f"banking onsets={sub['banking_crisis_onset'].sum():.0f}, "
              f"CA reversals={sub['ca_reversal'].sum():.0f}")

    # Interaction specification: only columns actually present are used.
    interact_vars = z_vars + ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
    available = [v for v in interact_vars if v in df.columns]
    absent = [v for v in interact_vars if v not in df.columns]
    if absent:
        # Make the silent degradation visible: without these columns the
        # 'Interaction' model is not the specification its label suggests.
        print(f"\n  Note: columns missing from panel, dropped from "
              f"'Interaction' model: {', '.join(absent)}")

    tercile_note = ("*Panel GLS with country and year fixed effects. "
                    "Standard errors in parentheses. "
                    "KAOPEN terciles: Closed (≤33rd pctile), Mid, Open (≥67th pctile).*\n"
                    "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    # ══════════════════════════════════════════════════════════════════
    # TABLES A and B: KAOPEN interaction + tercile split, per outcome
    # ══════════════════════════════════════════════════════════════════
    outcome_tables = [
        ('banking_crisis_onset', "TABLE A: BANKING CRISIS × KAOPEN",
         "kaopen_banking.md",
         "Capital Account Openness and Banking Crisis Prediction"),
        ('ca_reversal', "TABLE B: CA REVERSAL × KAOPEN",
         "kaopen_reversal.md",
         "Capital Account Openness and CA Reversal Prediction"),
    ]
    for y_var, banner, filename, title in outcome_tables:
        print("\n" + "=" * 60)
        print(banner)
        print("=" * 60)

        results = []
        # (1) Full sample with Z only; (2) full sample with Z × KAOPEN.
        full_sample_specs = [
            (z_vars + controls + ['kaopen'], 'Full Sample'),
            (available + controls + ['kaopen'], 'Interaction'),
        ]
        for x_vars, label in full_sample_specs:
            r = run_gls(df, y_var, x_vars, label)
            if r is not None:
                results.append(r)

        # (3-5) By KAOPEN tercile (kaopen itself is not a regressor here).
        results.extend(_fit_by_tercile(subsets, y_var, z_vars + controls, 'KAOPEN'))

        write_table(results, filename, title, note=tercile_note)

    # ══════════════════════════════════════════════════════════════════
    # TABLE C: Youth + Aging channels by KAOPEN tercile
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 60)
    print("TABLE C: YOUTH/AGING CHANNELS × KAOPEN")
    print("=" * 60)

    channel_vars = ['youth_dep', 'old_dep'] + controls
    # All banking-crisis columns first, then all CA-reversal columns.
    results_channels = _fit_by_tercile(subsets, 'banking_crisis_onset',
                                       channel_vars, 'Banking')
    results_channels.extend(_fit_by_tercile(subsets, 'ca_reversal',
                                            channel_vars, 'Reversal'))

    write_table(results_channels, "kaopen_channels.md",
                "Youth and Aging Channels by Capital Account Openness",
                note=("*Panel GLS with country and year fixed effects. "
                      "Standard errors in parentheses. "
                      "Separate regressions by KAOPEN tercile for banking crisis onset "
                      "and CA reversal.*\n"
                      "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*"))

    # ══════════════════════════════════════════════════════════════════
    # Summary statistics by KAOPEN tercile
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 60)
    print("SUMMARY: CRISIS INCIDENCE BY KAOPEN TERCILE")
    print("=" * 60)

    summary_lines = ["# Crisis Incidence by Capital Account Openness\n"]
    summary_lines.append("| | Closed | Mid | Open |")
    summary_lines.append("|:---|---:|---:|---:|")

    # Incidence rows: mean of a 0/1 indicator, shown as a percentage.
    for var, label in [('banking_crisis_onset', 'Banking onset rate'),
                       ('any_crisis_onset', 'Any crisis onset rate'),
                       ('ca_reversal', 'CA reversal rate'),
                       ('ca_reversal_5pp', 'CA reversal rate (5pp)')]:
        if var not in df.columns:
            # Same tolerance as the means rows below: an absent column drops
            # its row instead of aborting after all regressions have run.
            print(f"  {var}: column not in panel, omitted from summary")
            continue
        cells = "".join(f" {subsets[t][var].mean() * 100:.2f}% |" for t in terciles)
        summary_lines.append(f"| {label} |" + cells)

    # Plain means of selected covariates.
    for var, label in [('Z_1', 'Mean Z₁'),
                       ('youth_dep', 'Mean youth dep.'),
                       ('old_dep', 'Mean old dep.'),
                       ('nfa_gdp_lag', 'Mean NFA/GDP'),
                       ('ca_gdp', 'Mean CA/GDP')]:
        if var in df.columns:
            cells = "".join(f" {subsets[t][var].mean():.3f} |" for t in terciles)
            summary_lines.append(f"| {label} |" + cells)

    n_row = "| N |"
    nc_row = "| Countries |"
    for sub in subsets.values():
        n_row += f" {len(sub)} |"
        nc_row += f" {sub['iso3'].nunique()} |"
    summary_lines.append(n_row)
    summary_lines.append(nc_row)

    summary_lines.append("\n*Crisis rates are unconditional sample means (%). "
                         "KAOPEN terciles based on pooled distribution.*")

    path = OUT_TABLES / "kaopen_summary.md"
    # The summary contains non-ASCII text (e.g. "Z₁"); pin the encoding so the
    # write does not depend on the platform's default.
    path.write_text('\n'.join(summary_lines), encoding='utf-8')
    print(f"\n  Saved: {path}")

    print("\n" + "=" * 70)
    print("PHASE 7 COMPLETE")
    print("=" * 70)


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
